rm(list = ls())
library(tidyverse)
library(broom)
library(haven)
library(polycor)
library(wCorr)
library(quanteda)
library(quanteda.textplots)
require(glmnet)
require(quanteda.textstats)
library(ggcorrplot)
library(tidymodels)
set.seed(221186)

# Weighted correlation of x and y computed on the observations where both x
# and y are non-missing; `weights` is subset to the same observations.
# Further arguments in `...` are forwarded to weightedCorr() (wCorr).
weightedCorr_na <- function(x, y, weights, ...) {
  # A logical mask is used instead of negative indices: when nothing is
  # missing, x[-integer(0)] selects *no* elements rather than all of them.
  complete <- !is.na(x) & !is.na(y)

  weightedCorr(x[complete], y[complete], weights = weights[complete], ...)
}


# Mean absolute deviation of x around its centre, ignoring missing values.
# Without `weights` this is the plain mean of |x - mean(x)|. With `weights`
# both the centre and the average of the absolute deviations are weighted,
# so the measure is internally consistent.
mae <- function(x, weights) {
  if (missing(weights)) {
    centre <- mean(x, na.rm = TRUE)
    # weighted.mean() with no `w` reduces to sum(x) / length(x) after NA
    # removal -- the same computation the original one-liner performed.
    return(weighted.mean(abs(x - centre), na.rm = TRUE))
  }
  centre <- weighted.mean(x, w = weights, na.rm = TRUE)
  weighted.mean(abs(x - centre), w = weights, na.rm = TRUE)
}
# Share of non-missing responses that sit at either pole of the response
# scale. `poles` defaults to c(1, 5), the endpoints this script has always
# used (presumably a five-point scale -- confirm against the codebook).
# NA never matches `%in%`, so missing values count only in neither part of the
# ratio's numerator and are excluded from the denominator.
extreme <- function(x, poles = c(1, 5)) {
  sum(x %in% poles) / sum(!is.na(x))
}
# Sample standard deviation with missing responses dropped.
sd_narm <- function(x) {
  sd(x = x, na.rm = TRUE)
}

# Number of bootstrap resamples drawn by get_intervals().
n_boot_reps <- 500

# Load the prepared survey data.
# NOTE(review): this file is expected to supply just_w1, just_panel,
# just_topup and just_w1_topup_combined, which are used below but never
# created in this script -- confirm.
load(file = "../working/survey.Rdata")

## Outcome items analysed throughout (get_effects_all() keeps its own copy).
vars <- c(
  "trans_rights", "offensive_speech", "minimum_wage",
  "zero_hours", "unemployment_support", "high_tax"
)

# Polarization of each outcome item by treatment arm, for one data split.
#
# Arguments:
#   split      data frame with a `treat` column ("Treatment" / "Control") and
#              the outcome items.
#   wave       when 2, each item's `<item>_w2` column is first copied onto the
#              un-suffixed name; any other value uses the un-suffixed columns.
#   type       polarization measure: "mae", "sd" or "extreme".
#   threshold  optional; treated respondents whose `duration_trans_text` is
#              below it have every item set to NA (rows with an NA duration
#              are left untouched).
#
# Returns a tibble with one row per item: `name` (title-cased label),
# `Control`, `Treatment`, `Diff` (Treatment - Control) and `Diff_mean`
# (the across-item average of Diff, repeated on every row).
#
# NOTE(review): the measure is always called without weights, so `mae` is
# unweighted here even when the caller has adjusted a `weight` column (as
# the attrition-weighted analysis does) -- confirm whether that is intended.
get_effects_all <- function(split, wave = 2, type, threshold = NULL) {
  cat(".")  # progress marker: one dot per call, bootstrap draws included
  vars <- c("trans_rights", "offensive_speech", "minimum_wage", "zero_hours",
            "unemployment_support", "high_tax")

  if (wave == 2) {
    for (v in vars) {
      split[[v]] <- split[[paste0(v, "_w2")]]
    }
  }

  if (!is.null(threshold)) {
    too_fast <- split$treat == "Treatment" & split$duration_trans_text < threshold
    for (v in vars) {
      split[[v]][too_fast] <- NA
    }
  }

  # Explicit dispatch: an unknown `type` now fails with a clear message
  # instead of "object 'pol_func' not found" (or picking up a global).
  pol_func <- switch(type,
    mae = mae,
    sd = sd_narm,
    extreme = extreme,
    stop("Unknown polarization measure type: ", type, call. = FALSE)
  )

  split %>%
    group_by(treat) %>%
    summarise(across(all_of(vars), ~ pol_func(.x))) %>%
    pivot_longer(-1) %>%
    pivot_wider(names_from = treat, values_from = value) %>%
    mutate(Diff = Treatment - Control) %>%
    mutate(
      name = tools::toTitleCase(gsub("_", " ", name)),
      name = tools::toTitleCase(gsub("-", "/", name))
    ) %>%
    mutate(Diff_mean = sum(Diff) / length(vars))
}


# Point estimates plus 95% percentile-bootstrap intervals from
# get_effects_all().
#
# Rows of `split` are resampled with replacement `n_boot_reps` times (a
# global defined earlier in the script), and get_effects_all() is re-run on
# each draw. Extra arguments in `...` (e.g. `type`, `threshold`) are forwarded
# to every get_effects_all() call.
#
# Returns the point-estimate tibble (without `Diff`) extended with:
#   Diff_low, Diff_high        quantiles of the across-item average effect
#   high_/lo_treat, _control   quantiles of each arm's level, per item
#   high, low                  quantiles of Treatment - Control, per item
#   est                        point estimate of Treatment - Control
get_intervals <- function(split, wave = 1, ...) {
  ests <- get_effects_all(split, wave, ...)

  # seq_len() keeps an empty `split` from producing the bogus index c(1, 0);
  # for non-empty data it draws exactly the same resamples as 1:nrow().
  row_ids <- seq_len(nrow(split))
  boot_out <- tibble(fx = lapply(seq_len(n_boot_reps), function(rep) {
    get_effects_all(split[sample(row_ids, replace = TRUE), ], wave = wave, ...)
  }))

  ints <- boot_out %>%
    unnest(fx) %>%
    group_by(name) %>%
    summarise(high_treat = quantile(Treatment, 0.975),
              high_control = quantile(Control, 0.975),
              lo_treat = quantile(Treatment, 0.025),
              lo_control = quantile(Control, 0.025),
              high = quantile(Diff, 0.975),
              low = quantile(Diff, 0.025))

  # Diff_mean is constant within a replicate, so take one value per draw.
  boot_diff_mean <- vapply(boot_out$fx, function(fx) fx$Diff_mean[[1]], numeric(1))
  ests$Diff_low <- unname(quantile(boot_diff_mean, 0.025))
  ests$Diff_high <- unname(quantile(boot_diff_mean, 0.975))

  ests %>%
    full_join(ints, by = "name") %>%
    mutate(est = Diff) %>%
    select(-Diff)
}

## Estimate main effects
# MAE-based polarization effects with bootstrap intervals for each sample.
# `wave = 1` reads the un-suffixed outcome columns; `wave = 2` reads `_w2`.
polarization_effects_w1 <- get_intervals(just_w1, wave = 1, type = "mae")
polarization_effects_panel <- get_intervals(just_panel, wave = 2, type = "mae")
polarization_effects_topup <- get_intervals(just_topup, wave = 2, type = "mae")
polarization_effects_combined <- get_intervals(
  just_w1_topup_combined,
  wave = 2,
  type = "mae"
)

## Estimate attrition-weighted main effects
# Each sample's `weight` is multiplied by its attrition weight before
# bootstrapping.
# NOTE(review): get_effects_all() calls the polarization measure without any
# weights, so the adjusted `weight` column is never used. The point estimates
# therefore equal the unweighted main effects; only the bootstrap draws
# differ. Confirm whether weighting was intended.
polarization_effects_w1_attrition <- get_intervals(
  mutate(just_w1, weight = weight * weight_attrition_in_wave),
  wave = 1,
  type = "mae"
)

polarization_effects_panel_attrition <- get_intervals(
  mutate(just_panel, weight = weight * weight_attrition_panel),
  wave = 2,
  type = "mae"
)

polarization_effects_topup_attrition <- get_intervals(
  mutate(just_topup, weight = weight * weight_attrition_in_wave),
  wave = 2,
  type = "mae"
)

## Plot main effects

# Tag each result set with the facet label used in the plots below.
polarization_effects_w1$wave <- "Sample One, Wave One"
polarization_effects_panel$wave <- "Sample One, Wave Two"
polarization_effects_topup$wave <- "Sample Two"
polarization_effects_combined$wave <- c("Combined Sample")

# One row per plotted sample holding the across-item average effect
# (Diff_mean) and its bootstrap interval (Diff_low / Diff_high). Each
# get_intervals() result repeats these values on every row, hence unique().
# `name` is filled with an arbitrary item label, presumably so that layers
# inheriting the plot-level `y = name` mapping find that column.
# Note: later sections of this script reassign `average_treat_effects`.
average_treat_effects <- data.frame(wave = c("Sample One, Wave One", "Sample Two", "Sample One, Wave Two"),
                                    est = c(unique(polarization_effects_w1$Diff_mean), 
                                            unique(polarization_effects_topup$Diff_mean), 
                                            unique(polarization_effects_panel$Diff_mean)),
                                    low = c(unique(polarization_effects_w1$Diff_low), 
                                            unique(polarization_effects_topup$Diff_low), 
                                            unique(polarization_effects_panel$Diff_low)),
                                    high = c(unique(polarization_effects_w1$Diff_high), 
                                             unique(polarization_effects_topup$Diff_high), 
                                             unique(polarization_effects_panel$Diff_high)),
                                    name = polarization_effects_topup$name[1]) %>% 
  mutate(wave = factor(wave, levels = c("Sample One, Wave One", "Sample One, Wave Two", "Sample Two")))

# Per-item effects (Treatment - Control) with 95% percentile-bootstrap
# intervals, one facet per sample (the combined sample is not plotted).
# Items are ordered by their wave-one estimate. The shaded band (xmin/xmax
# inherited from the plot-level aes) and solid line show each sample's average
# effect; the dashed line marks zero. All colours are black and legends are
# suppressed.
bind_rows(polarization_effects_w1, polarization_effects_panel, polarization_effects_topup) %>%
  mutate(name = factor(name, levels = polarization_effects_w1$name[order(polarization_effects_w1$est)]),
         wave = factor(wave, levels = c("Sample One, Wave One", "Sample One, Wave Two", "Sample Two"))) %>%
  ggplot(aes(x = est, xmin = low, xmax = high, y = name, col = wave, fill = wave)) + 
  geom_pointrange(position = position_dodge(width = .2)) + 
  facet_wrap(~wave, ncol = 4) + 
  geom_rect(data = average_treat_effects, ymax = 100, ymin = -100, alpha = .2, lwd = 0.0001) + 
  geom_vline(aes(xintercept = est, col = wave), data = average_treat_effects) + 
  theme_bw() + 
  xlab("Effect of reason-giving on polarization") + 
  ylab("") + 
  theme(#panel.background = element_rect(fill = "transparent"),
        #plot.background = element_rect(fill = "transparent", color = NA),
        #legend.background = element_rect(fill = "transparent", color = NA),
        #legend.box.background = element_rect(fill = "transparent", color = NA),
        legend.key = element_blank(),
        legend.position = "bottom") +
  geom_vline(xintercept = 0, linetype = 2) + 
  scale_color_manual("", values = c("black", "black", "black", "black")) + 
  scale_fill_manual("", values = c("black", "black", "black", "black")) + 
  guides(fill = "none", color = "none") 

# ggsave() without `plot =` writes the most recently built plot.
ggsave("../out/outcomes/polarization.pdf", width = 7, height = 3)
ggsave("../out/outcomes/polarization.png", dpi = 600, width = 7, height = 3)

# Polarization levels (not differences) by arm: each item's Control and
# Treatment MAE with 95% percentile-bootstrap intervals, one facet per sample.
# After the rename, every selected column is named <stat>_<arm>, so
# pivot_longer() splits at "_" into a `.value` stat (est / high / lo) and an
# arm indicator `Var` (control / treat), which is then relabelled.
bind_rows(polarization_effects_w1, polarization_effects_panel, polarization_effects_topup) %>%
  select(name, Control, Treatment, high_treat, high_control, lo_treat, lo_control, wave) %>%
  rename(est_control = Control,
         est_treat = Treatment) %>%
  pivot_longer(cols = -c(name, wave),
               names_to = c(".value", "Var"),
               names_sep = "_") %>%
  mutate(name = factor(name, levels = polarization_effects_w1$name[order(polarization_effects_w1$est)]),
         wave = factor(wave, levels = c("Sample One, Wave One", "Sample One, Wave Two", "Sample Two")),
         Var = case_when(Var == "control" ~ "Control",
                         Var == "treat" ~ "Treatment")) %>%
  ggplot(aes(x = est, xmin = lo, xmax = high, y = name, col = Var, fill = Var)) + 
  geom_pointrange(position = position_dodge(width = .2)) + 
  facet_wrap(~wave, ncol = 4) + 
  theme_bw() + 
  xlab("Mean absolute error") + 
  ylab("") + 
  theme(panel.background = element_rect(fill = "transparent"),
        plot.background = element_rect(fill = "transparent", color = NA),
        legend.background = element_rect(fill = "transparent", color = NA),
        legend.box.background = element_rect(fill = "transparent", color = NA),
        legend.key = element_blank(),
        legend.position = "bottom") +
  scale_color_manual("", values = c("black", "gray")) + 
  scale_fill_manual("", values = c("black", "gray")) 

# Saves the most recently built plot (no `plot =` argument).
ggsave("../out/outcomes/polarization_levels.pdf", width = 12, height = 4)

## Plot attrition-weighted main effects
# Same figure as the main-effects plot, built from the attrition-weighted
# results. NOTE(review): get_effects_all() never passes weights to the
# measure, so these point estimates match the unweighted main effects.

polarization_effects_w1_attrition$wave <- "Sample One, Wave One"
polarization_effects_panel_attrition$wave <- "Sample One, Wave Two"
polarization_effects_topup_attrition$wave <- "Sample Two"

# Overwrites the global `average_treat_effects` with the attrition-weighted
# average effects and intervals (one row per sample).
average_treat_effects <- data.frame(wave = c("Sample One, Wave One", "Sample Two", "Sample One, Wave Two"),
                                    est = c(unique(polarization_effects_w1_attrition$Diff_mean), 
                                            unique(polarization_effects_topup_attrition$Diff_mean), 
                                            unique(polarization_effects_panel_attrition$Diff_mean)),
                                    low = c(unique(polarization_effects_w1_attrition$Diff_low), 
                                            unique(polarization_effects_topup_attrition$Diff_low), 
                                            unique(polarization_effects_panel_attrition$Diff_low)),
                                    high = c(unique(polarization_effects_w1_attrition$Diff_high), 
                                             unique(polarization_effects_topup_attrition$Diff_high), 
                                             unique(polarization_effects_panel_attrition$Diff_high)),
                                    name = polarization_effects_topup_attrition$name[1]) %>% 
  mutate(wave = factor(wave, levels = c("Sample One, Wave One", "Sample One, Wave Two", "Sample Two")))

# Per-item effects with bootstrap intervals; shaded band and solid line show
# each sample's average effect, dashed line marks zero.
bind_rows(polarization_effects_w1_attrition, polarization_effects_panel_attrition, polarization_effects_topup_attrition) %>%
  mutate(name = factor(name, levels = polarization_effects_w1_attrition$name[order(polarization_effects_w1_attrition$est)]),
         wave = factor(wave, levels = c("Sample One, Wave One", "Sample One, Wave Two", "Sample Two"))) %>%
  ggplot(aes(x = est, xmin = low, xmax = high, y = name, col = wave, fill = wave)) + 
  geom_pointrange(position = position_dodge(width = .2)) + 
  facet_wrap(~wave, ncol = 4) + 
  geom_rect(data = average_treat_effects, ymax = 100, ymin = -100, alpha = .2, lwd = 0.0001) + 
  geom_vline(aes(xintercept = est, col = wave), data = average_treat_effects) + 
  theme_bw() + 
  xlab("Effect of reason-giving on polarization") + 
  ylab("") + 
  theme(panel.background = element_rect(fill = "transparent"),
        plot.background = element_rect(fill = "transparent", color = NA),
        legend.background = element_rect(fill = "transparent", color = NA),
        legend.box.background = element_rect(fill = "transparent", color = NA),
        legend.key = element_blank(),
        legend.position = "bottom") +
  geom_vline(xintercept = 0, linetype = 2) + 
  scale_color_manual("", values = c("black", "black", "black", "black")) + 
  scale_fill_manual("", values = c("black", "black", "black", "black")) + 
  guides(fill = "none", color = "none") 

# Saves the most recently built plot (no `plot =` argument).
ggsave("../out/outcomes/polarization_attrition.pdf", width = 7, height = 3)

# Bundle the main MAE-based estimates for downstream scripts.
# The average-effect table is rebuilt here from the main-effects objects. The
# global `average_treat_effects` was last reassigned from the
# attrition-weighted results, so saving it directly would pair the main
# per-item estimates with the attrition-weighted averages and intervals.
average_treat_effects_main <- data.frame(
  wave = c("Sample One, Wave One", "Sample Two", "Sample One, Wave Two"),
  est = c(unique(polarization_effects_w1$Diff_mean),
          unique(polarization_effects_topup$Diff_mean),
          unique(polarization_effects_panel$Diff_mean)),
  low = c(unique(polarization_effects_w1$Diff_low),
          unique(polarization_effects_topup$Diff_low),
          unique(polarization_effects_panel$Diff_low)),
  high = c(unique(polarization_effects_w1$Diff_high),
           unique(polarization_effects_topup$Diff_high),
           unique(polarization_effects_panel$Diff_high)),
  name = polarization_effects_topup$name[1]
) %>%
  mutate(wave = factor(wave, levels = c("Sample One, Wave One", "Sample One, Wave Two", "Sample Two")))

polarization_effects_est <- list(polarization_effects_w1 = polarization_effects_w1,
                                 polarization_effects_panel = polarization_effects_panel,
                                 polarization_effects_topup = polarization_effects_topup,
                                 polarization_effects_combined = polarization_effects_combined,
                                 average_treat_effects = average_treat_effects_main)

save(polarization_effects_est, file = "../working/polarization_effects_est.Rdata")

# Heterogeneous effects

# Combine wave 1 responses with top-up responses in wave 2

# Copy the top-up sample's `_w2` answers onto the un-suffixed outcome names.
# Once stacked with the wave-one sample, a single `wave = 1` analysis then
# reads wave-two answers for top-up respondents and wave-one answers for
# sample one. This overwrites the global `just_topup`; later calls that use
# `wave = 2` re-copy the `_w2` columns and are unaffected.
just_topup <- just_topup %>% 
  mutate(high_tax = high_tax_w2,
         zero_hours = zero_hours_w2,
         unemployment_support = unemployment_support_w2,
         trans_rights = trans_rights_w2,
         offensive_speech = offensive_speech_w2,
         minimum_wage = minimum_wage_w2)

tmp <- bind_rows(just_topup, just_w1)

# Bootstrap MAE polarization effects separately within each subgroup of
# `data` defined by column `var`, visiting `groups` in the given order (the
# order fixes how the random-number stream is consumed). Returns a list of
# get_intervals() results named by group.
# Rows where `var` is NA are excluded. With `==`, their index would be NA,
# and row-subsetting with NA inserts all-NA rows that are then resampled in
# every bootstrap replicate. `%in%` never returns NA.
het_effects_by <- function(data, var, groups) {
  by_group <- lapply(groups, function(group) {
    get_intervals(data[data[[var]] %in% group, ], 1, type = "mae")
  })
  names(by_group) <- groups
  by_group
}

polarization_effects_attention <- het_effects_by(tmp, "attention_cat2", c("High", "Low"))
polarization_effects_gender <- het_effects_by(tmp, "gender", c("Male", "Female"))
polarization_effects_age <- het_effects_by(tmp, "age_cat", sort(unique(tmp$age_cat)))
polarization_effects_educ <- het_effects_by(tmp, "education.x", sort(unique(tmp$education.x)))
polarization_effects_eu_vote <- het_effects_by(tmp, "eu_vote", sort(unique(tmp$eu_vote)))
polarization_effects_vote19 <- het_effects_by(tmp, "vote19_simple", sort(unique(tmp$vote19_simple)))

polarization_het_effects <- list(attention = polarization_effects_attention,
                                 gender = polarization_effects_gender,
                                 age = polarization_effects_age,
                                 education = polarization_effects_educ,
                                 eu_vote = polarization_effects_eu_vote,
                                 vote19 = polarization_effects_vote19)

save(polarization_het_effects, file = "../working/het_effects_polarization.Rdata")


# Alternative measures of polarization

## Estimate main effects -- proportion of extreme responses
# Polarization measured as the share of responses at either scale pole.
# These calls reuse (and overwrite) the polarization_effects_* objects; the
# combined-sample result is computed but neither plotted nor saved here.
# just_topup was modified in the heterogeneous-effects section, but with
# wave = 2 its `_w2` columns are copied again, so that change has no effect.

polarization_effects_w1 <- get_intervals(just_w1, 1, type = "extreme")
polarization_effects_panel <- get_intervals(just_panel, 2, type = "extreme")
polarization_effects_topup <- get_intervals(just_topup, 2, type = "extreme")
polarization_effects_combined <- get_intervals(just_w1_topup_combined, 2, type = "extreme")

polarization_effects_w1$wave <- "Sample One, Wave One"
polarization_effects_panel$wave <- "Sample One, Wave Two"
polarization_effects_topup$wave <- "Sample Two"
polarization_effects_combined$wave <- c("Combined Sample")

# Across-item average effects and bootstrap intervals, one row per sample.
average_treat_effects <- data.frame(wave = c("Sample One, Wave One", "Sample Two", "Sample One, Wave Two"),
                                    est = c(unique(polarization_effects_w1$Diff_mean), 
                                            unique(polarization_effects_topup$Diff_mean), 
                                            unique(polarization_effects_panel$Diff_mean)),
                                    low = c(unique(polarization_effects_w1$Diff_low), 
                                            unique(polarization_effects_topup$Diff_low), 
                                            unique(polarization_effects_panel$Diff_low)),
                                    high = c(unique(polarization_effects_w1$Diff_high), 
                                             unique(polarization_effects_topup$Diff_high), 
                                             unique(polarization_effects_panel$Diff_high)),
                                    name = polarization_effects_topup$name[1]) %>% 
  mutate(wave = factor(wave, levels = c("Sample One, Wave One", "Sample One, Wave Two", "Sample Two")))

# Per-item effects with bootstrap intervals; shaded band and solid line show
# each sample's average effect, dashed line marks zero.
bind_rows(polarization_effects_w1, polarization_effects_panel, polarization_effects_topup) %>%
  mutate(name = factor(name, levels = polarization_effects_w1$name[order(polarization_effects_w1$est)]),
         wave = factor(wave, levels = c("Sample One, Wave One", "Sample One, Wave Two", "Sample Two"))) %>%
  ggplot(aes(x = est, xmin = low, xmax = high, y = name, col = wave, fill = wave)) + 
  geom_pointrange(position = position_dodge(width = .2)) + 
  facet_wrap(~wave, ncol = 4) + 
  geom_rect(data = average_treat_effects, ymax = 100, ymin = -100, alpha = .2, lwd = 0.0001) + 
  geom_vline(aes(xintercept = est, col = wave), data = average_treat_effects) + 
  theme_bw() + 
  xlab("Effect of reason-giving on polarization") + 
  ylab("") + 
  theme(panel.background = element_rect(fill = "transparent"),
        plot.background = element_rect(fill = "transparent", color = NA),
        legend.background = element_rect(fill = "transparent", color = NA),
        legend.box.background = element_rect(fill = "transparent", color = NA),
        legend.key = element_blank(),
        legend.position = "bottom") +
  geom_vline(xintercept = 0, linetype = 2) + 
  scale_color_manual("", values = c("black", "black", "black", "black")) + 
  scale_fill_manual("", values = c("black", "black", "black", "black")) + 
  guides(fill = "none", color = "none") 

# Saves the most recently built plot (no `plot =` argument).
ggsave("../out/outcomes/polarization_extreme.pdf", width = 7, height = 3)


## Estimate main effects -- standard deviation
# Polarization measured as the standard deviation of responses. These calls
# again overwrite the polarization_effects_* objects; the combined-sample
# result is computed but neither plotted nor saved here.

polarization_effects_w1 <- get_intervals(just_w1, 1, type = "sd")
polarization_effects_panel <- get_intervals(just_panel, 2, type = "sd")
polarization_effects_topup <- get_intervals(just_topup, 2, type = "sd")
polarization_effects_combined <- get_intervals(just_w1_topup_combined, 2, type = "sd")

polarization_effects_w1$wave <- "Sample One, Wave One"
polarization_effects_panel$wave <- "Sample One, Wave Two"
polarization_effects_topup$wave <- "Sample Two"
polarization_effects_combined$wave <- c("Combined Sample")

# Across-item average effects and bootstrap intervals, one row per sample.
average_treat_effects <- data.frame(wave = c("Sample One, Wave One", "Sample Two", "Sample One, Wave Two"),
                                    est = c(unique(polarization_effects_w1$Diff_mean), 
                                            unique(polarization_effects_topup$Diff_mean), 
                                            unique(polarization_effects_panel$Diff_mean)),
                                    low = c(unique(polarization_effects_w1$Diff_low), 
                                            unique(polarization_effects_topup$Diff_low), 
                                            unique(polarization_effects_panel$Diff_low)),
                                    high = c(unique(polarization_effects_w1$Diff_high), 
                                             unique(polarization_effects_topup$Diff_high), 
                                             unique(polarization_effects_panel$Diff_high)),
                                    name = polarization_effects_topup$name[1]) %>% 
  mutate(wave = factor(wave, levels = c("Sample One, Wave One", "Sample One, Wave Two", "Sample Two")))

# Per-item effects with bootstrap intervals; shaded band and solid line show
# each sample's average effect, dashed line marks zero.
bind_rows(polarization_effects_w1, polarization_effects_panel, polarization_effects_topup) %>%
  mutate(name = factor(name, levels = polarization_effects_w1$name[order(polarization_effects_w1$est)]),
         wave = factor(wave, levels = c("Sample One, Wave One", "Sample One, Wave Two", "Sample Two"))) %>%
  ggplot(aes(x = est, xmin = low, xmax = high, y = name, col = wave, fill = wave)) + 
  geom_pointrange(position = position_dodge(width = .2)) + 
  facet_wrap(~wave, ncol = 4) + 
  geom_rect(data = average_treat_effects, ymax = 100, ymin = -100, alpha = .2, lwd = 0.0001) + 
  geom_vline(aes(xintercept = est, col = wave), data = average_treat_effects) + 
  theme_bw() + 
  xlab("Effect of reason-giving on polarization") + 
  ylab("") + 
  theme(panel.background = element_rect(fill = "transparent"),
        plot.background = element_rect(fill = "transparent", color = NA),
        legend.background = element_rect(fill = "transparent", color = NA),
        legend.box.background = element_rect(fill = "transparent", color = NA),
        legend.key = element_blank(),
        legend.position = "bottom") +
  geom_vline(xintercept = 0, linetype = 2) + 
  scale_color_manual("", values = c("black", "black", "black", "black")) + 
  scale_fill_manual("", values = c("black", "black", "black", "black")) + 
  guides(fill = "none", color = "none") 

# Saves the most recently built plot (no `plot =` argument).
ggsave("../out/outcomes/polarization_sd.pdf", width = 7, height = 3)

## Estimate main effects with duration threshold
# Re-estimate the MAE effects after blanking the outcomes of treated
# respondents whose `duration_trans_text` is below 30 (presumably seconds
# spent on the reason-giving text -- confirm units).
polarization_effects_w1_thresh <- get_intervals(
  just_w1, wave = 1, type = "mae", threshold = 30
)
polarization_effects_panel_thresh <- get_intervals(
  just_panel, wave = 2, type = "mae", threshold = 30
)

# Across-item average effects and their bootstrap intervals, one row per sample.
polarization_average_treat_effects_thresh <- mutate(
  data.frame(
    wave = c("Sample One, Wave One", "Sample One, Wave Two"),
    est = c(unique(polarization_effects_w1_thresh$Diff_mean),
            unique(polarization_effects_panel_thresh$Diff_mean)),
    low = c(unique(polarization_effects_w1_thresh$Diff_low),
            unique(polarization_effects_panel_thresh$Diff_low)),
    high = c(unique(polarization_effects_w1_thresh$Diff_high),
             unique(polarization_effects_panel_thresh$Diff_high))
  ),
  wave = factor(wave, levels = c("Sample One, Wave One", "Sample One, Wave Two"))
)

save(
  polarization_effects_w1_thresh,
  polarization_effects_panel_thresh,
  polarization_average_treat_effects_thresh,
  file = "../working/dur_effects_polarization.Rdata"
)
